Skip to content

ensure consistent topic and topic tag when classifying #6673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 209 additions & 0 deletions kitsune/questions/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from copy import deepcopy
from unittest.mock import patch

from django.contrib.contenttypes.models import ContentType
from parameterized import parameterized

from kitsune.flagit.models import FlaggedObject
from kitsune.llm.questions.classifiers import ModerationAction
from kitsune.products.tests import TopicFactory
from kitsune.questions.models import Answer, Question
from kitsune.questions.tests import AnswerFactory, QuestionFactory
from kitsune.questions.utils import (
Expand All @@ -10,10 +15,12 @@
num_answers,
num_questions,
num_solutions,
process_classification_result,
remove_pii,
remove_home_dir_pii,
)
from kitsune.sumo.tests import TestCase
from kitsune.users.models import Profile
from kitsune.users.tests import UserFactory


Expand Down Expand Up @@ -209,3 +216,205 @@ def test_remove_pii(self):
] = "C:\\Users\\<USERNAME>\\AppData\\Local\\Mozilla\\Firefox"
remove_pii(data)
self.assertDictEqual(data, expected)


class ProcessClassificationResultTests(TestCase):
    """Tests for ``process_classification_result``.

    Covers the three kinds of classification outcome exercised below: a
    question marked as spam, a question flagged for manual review, and a
    question assigned a topic (with the topic tag kept in sync).
    """

    def setUp(self):
        self.topic1 = TopicFactory()
        self.topic2 = TopicFactory()
        self.sumo_bot = Profile.get_sumo_bot()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _flags(question, **filters):
        """Return the ``FlaggedObject`` queryset for ``question``, narrowed by ``filters``."""
        q_ct = ContentType.objects.get_for_model(question)
        return FlaggedObject.objects.filter(
            content_type=q_ct, object_id=question.id, **filters
        )

    @staticmethod
    def _tag_names(question):
        """Return the set of tag names reported by ``question.my_tags``."""
        return {tag.name for tag in question.my_tags}

    @staticmethod
    def _topic_classification(topic, reason):
        """Build a NOT_SPAM classification result that selects ``topic``."""
        return dict(
            action=ModerationAction.NOT_SPAM,
            topic_result=dict(topic=topic.title, reason=reason),
        )

    def _assert_untouched(self, question):
        """Assert the question is neither spam nor flagged yet."""
        self.assertFalse(question.is_spam)
        self.assertFalse(self._flags(question).exists())

    def _assert_topic_flag(self, question, note):
        """Assert the bot recorded an accepted content-moderation flag containing ``note``."""
        self.assertTrue(
            self._flags(
                question,
                creator=self.sumo_bot,
                status=FlaggedObject.FLAG_ACCEPTED,
                reason=FlaggedObject.REASON_CONTENT_MODERATION,
                notes__contains=note,
            ).exists()
        )

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_spam_result(self):
        question = QuestionFactory(topic=self.topic1)
        classification_result = dict(
            action=ModerationAction.SPAM,
        )
        self.assertFalse(question.is_spam)
        self.assertIsNone(question.marked_as_spam)
        self.assertIsNone(question.marked_as_spam_by)
        self.assertEqual(question.topic, self.topic1)

        process_classification_result(question, classification_result)

        question.refresh_from_db()

        self.assertTrue(question.is_spam)
        self.assertIsNotNone(question.marked_as_spam)
        self.assertEqual(question.marked_as_spam_by, self.sumo_bot)

    def test_flagged_result(self):
        question = QuestionFactory(topic=self.topic1)
        classification_result = dict(
            action=ModerationAction.FLAG_REVIEW,
            spam_result=dict(reason="I think it is spam?"),
        )

        self._assert_untouched(question)
        self.assertEqual(question.topic, self.topic1)

        process_classification_result(question, classification_result)

        question.refresh_from_db()

        self.assertFalse(question.is_spam)
        self.assertEqual(question.topic, self.topic1)
        self.assertTrue(
            self._flags(
                question,
                creator=self.sumo_bot,
                reason=FlaggedObject.REASON_SPAM,
                status=FlaggedObject.FLAG_PENDING,
                notes__contains="I think it is spam?",
            ).exists()
        )

    def test_topic_result_with_change(self):
        question = QuestionFactory(topic=self.topic1, tags=[self.topic1.slug])
        classification_result = self._topic_classification(
            self.topic2, "Dude, it is so topic2."
        )

        self._assert_untouched(question)
        self.assertEqual(question.topic, self.topic1)
        self.assertEqual(self._tag_names(question), {self.topic1.slug})

        process_classification_result(question, classification_result)

        question.refresh_from_db()

        # Both the topic and its tag must have moved to topic2.
        self.assertFalse(question.is_spam)
        self.assertEqual(question.topic, self.topic2)
        self.assertEqual(self._tag_names(question), {self.topic2.slug})
        self._assert_topic_flag(question, "Dude, it is so topic2.")

    def test_topic_result_with_no_initial_topic(self):
        question = QuestionFactory(topic=None)
        classification_result = self._topic_classification(
            self.topic2, "Dude, it is so topic2."
        )

        self._assert_untouched(question)
        self.assertIsNone(question.topic)
        self.assertFalse(question.my_tags)

        process_classification_result(question, classification_result)

        question.refresh_from_db()

        self.assertFalse(question.is_spam)
        self.assertEqual(question.topic, self.topic2)
        self.assertEqual(self._tag_names(question), {self.topic2.slug})
        self._assert_topic_flag(question, "Dude, it is so topic2.")

    def test_topic_result_with_no_change(self):
        question = QuestionFactory(topic=self.topic1, tags=[self.topic1.slug])
        classification_result = self._topic_classification(
            self.topic1, "Dude, it is so topic1."
        )

        self._assert_untouched(question)
        self.assertEqual(question.topic, self.topic1)
        self.assertEqual(self._tag_names(question), {self.topic1.slug})

        process_classification_result(question, classification_result)

        question.refresh_from_db()

        # The topic is unchanged, but the classification is still recorded as a flag.
        self.assertFalse(question.is_spam)
        self.assertEqual(question.topic, self.topic1)
        self.assertEqual(self._tag_names(question), {self.topic1.slug})
        self._assert_topic_flag(question, "Dude, it is so topic1.")

    def test_topic_result_with_incomplete_transaction(self):
        class SaveFailed(Exception):
            """Marker error so only the simulated save failure satisfies assertRaises."""

        question = QuestionFactory(topic=self.topic1, tags=[self.topic1.slug])
        classification_result = self._topic_classification(
            self.topic2, "Dude, it is so topic2."
        )

        self._assert_untouched(question)
        self.assertEqual(question.topic, self.topic1)
        self.assertEqual(self._tag_names(question), {self.topic1.slug})

        # A dedicated exception type keeps an unrelated error raised inside
        # process_classification_result from being mistaken for the save failure.
        failing_save = patch.object(question, "save", side_effect=SaveFailed)
        with failing_save, self.assertRaises(SaveFailed):
            process_classification_result(question, classification_result)

        question.refresh_from_db()
        # NOTE(review): my_tags appears to be cache-backed (the code under test
        # calls clear_cached_tags() only after a successful update), so drop any
        # cached value to make the tag assertion below reflect the database.
        question.clear_cached_tags()

        # Since one of the DB changes failed, they all should be rolled back.
        self.assertFalse(question.is_spam)
        self.assertEqual(question.topic, self.topic1)
        self.assertEqual(self._tag_names(question), {self.topic1.slug})
        self.assertFalse(self._flags(question).exists())
41 changes: 26 additions & 15 deletions kitsune/questions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from django.contrib.sessions.backends.base import SessionBase
from django.db import transaction

from kitsune.flagit.models import FlaggedObject
from kitsune.llm.questions.classifiers import ModerationAction
from kitsune.products.models import Product, Topic
from kitsune.questions.models import Answer, Question

# from kitsune.tags.models import SumoTag
from kitsune.users.models import Profile
from kitsune.wiki.utils import get_featured_articles as kb_get_featured_articles
from kitsune.wiki.utils import has_visited_kb
Expand Down Expand Up @@ -186,20 +189,28 @@ def process_classification_result(
reason=FlaggedObject.REASON_SPAM,
)
case _:
if topic_title := result["topic_result"].get("topic"):
try:
topic = Topic.active.get(title=topic_title, visible=True)
except (Topic.DoesNotExist, Topic.MultipleObjectsReturned):
return
else:
flag_question(
question,
by_user=sumo_bot,
notes=(
"LLM classified as {topic.title}, for the following reason:\n"
f"{result['topic_result']['reason']}"
),
status=FlaggedObject.FLAG_ACCEPTED,
)
if not (topic_title := result["topic_result"].get("topic")):
return

try:
topic = Topic.active.get(title=topic_title, visible=True)
except (Topic.DoesNotExist, Topic.MultipleObjectsReturned):
return

with transaction.atomic():
flag_question(
question,
by_user=sumo_bot,
notes=(
"LLM classified as {topic.title}, for the following reason:\n"
f"{result['topic_result']['reason']}"
),
status=FlaggedObject.FLAG_ACCEPTED,
)
if topic != question.topic:
if question.topic:
question.tags.remove(question.topic.slug)
question.topic = topic
question.save()
question.tags.add(topic.slug)
question.clear_cached_tags()